{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as T\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchaudio\n",
    "from utils import load_ckpt, print_colored\n",
    "from tokenizer import make_tokenizer\n",
    "from model import get_hertz_dev_config\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import Audio, display\n",
    "\n",
    "\n",
    "# If you get an error like \"undefined symbol: __nvJitLinkComplete_12_4, version libnvJitLink.so.12\",\n",
    "# you need to install PyTorch with the correct CUDA version. Run:\n",
    "# `pip3 uninstall torch torchaudio && pip3 install torch torchaudio --index-url https://download.pytorch.org/whl/cu121`\n",
    "\n",
    "# Prefer the GPU when one is available, otherwise fall back to the CPU.\n",
    "device = 'cuda' if T.cuda.is_available() else 'cpu'\n",
    "if device == 'cuda':\n",
    "    # Pin work to the first GPU. Guarded because T.cuda.set_device raises on hosts\n",
    "    # without a usable CUDA device, which would defeat the CPU fallback above.\n",
    "    T.cuda.set_device(0)\n",
    "print_colored(f\"Using device: {device}\", \"grey\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# make_tokenizer will automatically download the files it needs (e.g. checkpoints) if it can't find them.\n",
    "audio_tokenizer = make_tokenizer(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We have different checkpoints for the single-speaker and two-speaker models.\n",
    "# Set to True to load and run inference with the two-speaker model.\n",
    "TWO_SPEAKER = False\n",
    "# We trained a base model with no text initialization at all. Toggle this to enable it.\n",
    "USE_PURE_AUDIO_ABLATION = False\n",
    "\n",
    "# We only have a single-speaker version of the pure-audio model, so the two flags\n",
    "# cannot both be enabled. Fail with an explicit message instead of a bare assert.\n",
    "assert not (USE_PURE_AUDIO_ABLATION and TWO_SPEAKER), (\n",
    "    \"USE_PURE_AUDIO_ABLATION is only available for the single-speaker model; \"\n",
    "    \"set TWO_SPEAKER = False to use it.\"\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model configuration from the two flags chosen in the previous cell.\n",
    "model_config = get_hertz_dev_config(\n",
    "    is_split=TWO_SPEAKER,\n",
    "    use_pure_audio_ablation=USE_PURE_AUDIO_ABLATION,\n",
    ")\n",
    "\n",
    "# Instantiate the generator, switch it to eval mode, cast it to bfloat16 and move it to the target device.\n",
    "generator = model_config().eval().to(T.bfloat16).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_and_preprocess_audio(audio_path):\n",
    "    \"\"\"Load an audio file and prepare it as a batched 16 kHz tensor.\n",
    "\n",
    "    The channel layout follows the global TWO_SPEAKER flag: mono input is\n",
    "    duplicated to two channels in two-speaker mode, and any multi-channel input\n",
    "    is averaged down to a single channel otherwise. The audio is then resampled\n",
    "    to 16 kHz if needed and truncated to at most five minutes.\n",
    "\n",
    "    Args:\n",
    "        audio_path: Path of the audio file, passed straight to torchaudio.load.\n",
    "\n",
    "    Returns:\n",
    "        The processed audio with a leading batch dimension of size 1 added,\n",
    "        i.e. a tensor of shape (1, channels, samples).\n",
    "    \"\"\"\n",
    "    target_sr = 16000\n",
    "    max_samples = target_sr * 60 * 5  # five minutes at the target sample rate\n",
    "\n",
    "    print_colored(\"Loading and preprocessing audio...\", \"blue\", bold=True)\n",
    "    # Load audio file; the code below treats dim 0 as channels and dim 1 as samples.\n",
    "    audio_tensor, sr = torchaudio.load(audio_path)\n",
    "    print_colored(f\"Loaded audio shape: {audio_tensor.shape}\", \"grey\")\n",
    "\n",
    "    num_channels = audio_tensor.shape[0]\n",
    "    if TWO_SPEAKER:\n",
    "        if num_channels == 1:\n",
    "            print_colored(\"Converting mono to stereo...\", \"grey\")\n",
    "            audio_tensor = audio_tensor.repeat(2, 1)\n",
    "            print_colored(f\"Stereo audio shape: {audio_tensor.shape}\", \"grey\")\n",
    "    elif num_channels > 1:\n",
    "        # Downmix every multi-channel layout, not only exact stereo, so files with\n",
    "        # more than two channels also reach the single-speaker model as mono.\n",
    "        layout = \"stereo\" if num_channels == 2 else f\"{num_channels}-channel audio\"\n",
    "        print_colored(f\"Converting {layout} to mono...\", \"grey\")\n",
    "        audio_tensor = audio_tensor.mean(dim=0, keepdim=True)\n",
    "        print_colored(f\"Mono audio shape: {audio_tensor.shape}\", \"grey\")\n",
    "\n",
    "    # Resample to 16kHz if needed\n",
    "    if sr != target_sr:\n",
    "        print_colored(f\"Resampling from {sr}Hz to {target_sr}Hz...\", \"grey\")\n",
    "        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)\n",
    "        audio_tensor = resampler(audio_tensor)\n",
    "\n",
    "    # Clip to 5 minutes if needed\n",
    "    if audio_tensor.shape[1] > max_samples:\n",
    "        print_colored(\"Clipping audio to 5 minutes...\", \"grey\")\n",
    "        audio_tensor = audio_tensor[:, :max_samples]\n",
    "\n",
    "    print_colored(\"Audio preprocessing complete!\", \"green\")\n",
    "    # Add the batch dimension.\n",
    "    return audio_tensor.unsqueeze(0)\n",
    "\n",
    "def display_audio(audio_tensor):\n",
    "    \"\"\"Plot a small waveform and show an inline audio player for a tensor.\n",
    "\n",
    "    Args:\n",
    "        audio_tensor: Audio samples as a tensor. It is moved to the CPU, squeezed\n",
    "            and converted to float before display; a 1-D result is treated as a\n",
    "            single channel.\n",
    "    \"\"\"\n",
    "    samples = audio_tensor.cpu().squeeze()\n",
    "    # Guarantee a channel dimension so indexing channel 0 below always works.\n",
    "    samples = samples.unsqueeze(0) if samples.ndim == 1 else samples\n",
    "    waveform = samples.float().numpy()\n",
    "\n",
    "    # Waveform thumbnail of the first channel, drawn without axes.\n",
    "    plt.figure(figsize=(4, 1))\n",
    "    plt.plot(waveform[0], linewidth=0.5)\n",
    "    plt.axis('off')\n",
    "    plt.show()\n",
    "\n",
    "    # Inline audio player, played back at 16 kHz.\n",
    "    display(Audio(waveform, rate=16000))\n",
    "    print_colored(\"Audio ready for playback ↑\", \"green\", bold=True)\n",
    "    \n",
    "    \n",
    "\n",
    "# Our model is very prompt-sensitive, so we recommend experimenting with a diverse set of prompts.\n",
    "prompt_audio = load_and_preprocess_audio('./prompts/toaskanymore.wav')\n",
    "display_audio(prompt_audio)\n",
    "\n",
    "# The prompt length is counted in latent frames, at 8 frames per second of audio.\n",
    "prompt_len_seconds = 3\n",
    "prompt_len = 8 * prompt_len_seconds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print_colored(\"Encoding prompt...\", \"blue\")\n",
    "# Run autocast on whichever device was selected earlier instead of hard-coding 'cuda',\n",
    "# so the CPU fallback gets bfloat16 autocast too (on a GPU host this is unchanged).\n",
    "with T.autocast(device_type=T.device(device).type, dtype=T.bfloat16):\n",
    "    if TWO_SPEAKER:\n",
    "        # Encode each channel on its own, then join the two results along the last dimension.\n",
    "        encoded_prompt_audio_ch1 = audio_tokenizer.latent_from_data(prompt_audio[:, 0:1].to(device))\n",
    "        encoded_prompt_audio_ch2 = audio_tokenizer.latent_from_data(prompt_audio[:, 1:2].to(device))\n",
    "        encoded_prompt_audio = T.cat([encoded_prompt_audio_ch1, encoded_prompt_audio_ch2], dim=-1)\n",
    "    else:\n",
    "        encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device))\n",
    "print_colored(f\"Encoded prompt shape: {encoded_prompt_audio.shape}\", \"grey\")\n",
    "print_colored(\"Prompt encoded successfully!\", \"green\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_completion(encoded_prompt_audio, prompt_len, gen_len=None):\n",
    "    \"\"\"Generate a continuation of the encoded prompt and decode it to audio.\n",
    "\n",
    "    Args:\n",
    "        encoded_prompt_audio: Encoded prompt latents; only the first prompt_len\n",
    "            positions along dimension 1 are used.\n",
    "        prompt_len: Number of prompt latent frames to condition on (8 frames per\n",
    "            second of audio).\n",
    "        gen_len: Forwarded unchanged to generator.completion as gen_len.\n",
    "\n",
    "    Returns:\n",
    "        A float CPU tensor with a leading channel dimension, scaled down to a\n",
    "        peak of 1 if it exceeded that, and trimmed so that it starts one second\n",
    "        before the end of the prompt (or at the beginning for short prompts).\n",
    "    \"\"\"\n",
    "    prompt_len_seconds = prompt_len / 8\n",
    "    print_colored(f\"Prompt length: {prompt_len_seconds:.2f}s\", \"grey\")\n",
    "    print_colored(\"Completing audio...\", \"blue\")\n",
    "    encoded_prompt_audio = encoded_prompt_audio[:, :prompt_len]\n",
    "    # Follow the globally selected device instead of hard-coding 'cuda', so autocast\n",
    "    # is also active on the CPU fallback (on a GPU host this is unchanged).\n",
    "    with T.autocast(device_type=T.device(device).type, dtype=T.bfloat16):\n",
    "        completed_audio = generator.completion(\n",
    "            encoded_prompt_audio,\n",
    "            temps=(.8, (0.5, 0.1)), # (token_temp, (categorical_temp, gaussian_temp))\n",
    "            use_cache=True,\n",
    "            gen_len=gen_len)\n",
    "\n",
    "        print_colored(\"Decoding completion...\", \"blue\")\n",
    "        if TWO_SPEAKER:\n",
    "            # The first 32 latent features are decoded as channel 1 and the rest as channel 2.\n",
    "            decoded_completion_ch1 = audio_tokenizer.data_from_latent(completed_audio[:, :, :32].bfloat16())\n",
    "            decoded_completion_ch2 = audio_tokenizer.data_from_latent(completed_audio[:, :, 32:].bfloat16())\n",
    "            decoded_completion = T.cat([decoded_completion_ch1, decoded_completion_ch2], dim=0)\n",
    "        else:\n",
    "            decoded_completion = audio_tokenizer.data_from_latent(completed_audio.bfloat16())\n",
    "        print_colored(f\"Decoded completion shape: {decoded_completion.shape}\", \"grey\")\n",
    "\n",
    "    print_colored(\"Preparing audio for playback...\", \"blue\")\n",
    "\n",
    "    audio_tensor = decoded_completion.cpu().squeeze()\n",
    "    if audio_tensor.ndim == 1:\n",
    "        audio_tensor = audio_tensor.unsqueeze(0)\n",
    "    audio_tensor = audio_tensor.float()\n",
    "\n",
    "    # Scale down only when the signal would clip; quieter audio is left untouched.\n",
    "    peak = audio_tensor.abs().max()\n",
    "    if peak > 1:\n",
    "        audio_tensor = audio_tensor / peak\n",
    "\n",
    "    # 2000 samples per latent frame (16000 Hz / 8 frames per second): keep the last\n",
    "    # second (16000 samples) of the prompt as lead-in, clamped at the start of the audio.\n",
    "    return audio_tensor[:, max(prompt_len*2000 - 16000, 0):]\n",
    "\n",
    "num_completions = 10\n",
    "print_colored(f\"Generating {num_completions} completions...\", \"blue\")\n",
    "for _ in range(num_completions):\n",
    "    # gen_len is counted in latent frames (8 per second): 20 * 8 frames is 20 seconds of generation.\n",
    "    completion = get_completion(\n",
    "        encoded_prompt_audio=encoded_prompt_audio,\n",
    "        prompt_len=prompt_len,\n",
    "        gen_len=20 * 8,\n",
    "    )\n",
    "    display_audio(audio_tensor=completion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
